{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Semi-supervised Community Detection using Graph Neural Networks\n",
    "\n",
    "Almost every computer 101 class starts with a \"Hello World\" example. Like MNIST for deep learning, in graph domain we have the Zachary's Karate Club problem. The karate club is a social network that includes 34 members and documents pairwise links between members who interact outside the club. The club later divides into two communities led by the instructor (node 0) and the club president (node 33). The network is visualized as follows with the color indicating the community.\n",
    "<img src='../asset/karat_club.png' align='center' width=\"400px\" height=\"300px\" />\n",
    "\n",
    "In this tutorial, you will learn:\n",
    "\n",
    "* Formulate the community detection problem as a semi-supervised node classification task.\n",
    "* Build a GraphSAGE model, a popular Graph Neural Network architecture proposed by [Hamilton et al.](https://arxiv.org/abs/1706.02216)\n",
    "* Train the model and understand the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tutorial_utils import setup_tf\n",
    "setup_tf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl\n",
    "import tensorflow as tf\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Community detection as node classification\n",
    "\n",
    "The study of community structure in graphs has a long history. Many proposed methods are *unsupervised* (or *self-supervised* by recent definition), where the model predicts the community labels only by connectivity. Recently, [Kipf et al.,](https://arxiv.org/abs/1609.02907) proposed to formulate the community detection problem as a semi-supervised node classification task. With the help of only a small portion of labeled nodes, a GNN can accurately predict the community labels of the others.\n",
    "\n",
    "In this tutorial, we apply Kipf's setting to the Zachery's Karate Club network to predict the community membership, where only the labels of a few nodes are used."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first load the graph and node labels as is covered in the [last session](./1_load_data.ipynb). Here, we have provided you a function for loading the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tutorial_utils import load_zachery\n",
    "\n",
    "# ----------- 0. load graph -------------- #\n",
    "g = load_zachery()\n",
    "print(g)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the original Zachery's Karate Club graph, nodes are feature-less. (The `'Age'` attribute is an artificial one mainly for tutorial purposes). For feature-less graph, a common practice is to use an embedding weight that is updated during training for every node.\n",
    "\n",
    "We can use Keras's `Embedding` module to achieve this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------- 1. node features -------------- #\n",
    "node_embed = tf.keras.layers.Embedding(g.number_of_nodes(), 5,\n",
    "                                       embeddings_initializer='glorot_uniform')  # Every node has an embedding of size 5.\n",
    "node_embed(1) # intialize embedding layer\n",
    "inputs = node_embed.embeddings # the embedding matrix\n",
    "print(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The community label is stored in the `'club'` node feature (0 for instructor, 1 for club president). Only nodes 0 and 33 are labeled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = g.ndata['club']\n",
    "labeled_nodes = [0, 33]\n",
    "print('Labels', tf.gather(labels, labeled_nodes))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a GraphSAGE model\n",
    "\n",
    "Our model consists of two layers, each computes new node representations by aggregating neighbor information. The equations are:\n",
    "\n",
    "$$\n",
    "h_{\\mathcal{N}(v)}^k\\leftarrow \\text{AGGREGATE}_k\\{h_u^{k-1},\\forall u\\in\\mathcal{N}(v)\\}\n",
    "$$\n",
    "\n",
    "$$\n",
    "h_v^k\\leftarrow \\sigma\\left(W^k\\cdot \\text{CONCAT}(h_v^{k-1}, h_{\\mathcal{N}(v)}^k) \\right)\n",
    "$$\n",
    "\n",
    "DGL provides implementation of many popular neighbor aggregation modules. They all can be invoked easily with one line of codes. See the full list of supported [graph convolution modules](https://docs.dgl.ai/api/python/nn.pytorch.html#module-dgl.nn.pytorch.conv)."
   ]
  },
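  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the `'mean'` aggregator used in this tutorial, the two equations above can be written out more concretely. This is a sketch that assumes the aggregator follows the GraphSAGE paper; $W_{\\text{self}}^k$ and $W_{\\text{neigh}}^k$ are notation introduced here for the two column blocks of $W^k$, and any bias term is omitted.\n",
    "\n",
    "$$\n",
    "h_{\\mathcal{N}(v)}^k = \\frac{1}{|\\mathcal{N}(v)|}\\sum_{u\\in\\mathcal{N}(v)} h_u^{k-1}\n",
    "$$\n",
    "\n",
    "$$\n",
    "W^k\\cdot \\text{CONCAT}(h_v^{k-1}, h_{\\mathcal{N}(v)}^k) = W_{\\text{self}}^k h_v^{k-1} + W_{\\text{neigh}}^k h_{\\mathcal{N}(v)}^k\n",
    "$$\n",
    "\n",
    "In other words, each layer mixes a node's own representation with the average of its neighbors' representations before applying $\\sigma$."
   ]
  },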
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dgl.nn import SAGEConv\n",
    "\n",
    "# ----------- 2. create model -------------- #\n",
    "# build a two-layer GraphSAGE model\n",
    "class GraphSAGE(tf.keras.layers.Layer):\n",
    "    def __init__(self, in_feats, h_feats, num_classes):\n",
    "        super(GraphSAGE, self).__init__()\n",
    "        self.conv1 = SAGEConv(in_feats, h_feats, 'mean')\n",
    "        self.conv2 = SAGEConv(h_feats, num_classes, 'mean')\n",
    "    \n",
    "    def call(self, g, in_feat):\n",
    "        h = self.conv1(g, in_feat)\n",
    "        h = tf.nn.relu(h)\n",
    "        h = self.conv2(g, h)\n",
    "        return h\n",
    "    \n",
    "# Create the model with given dimensions \n",
    "# input layer dimension: 5, node embeddings\n",
    "# hidden layer dimension: 16\n",
    "# output layer dimension: 2, the two classes, 0 and 1\n",
    "net = GraphSAGE(5, 16, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------- 3. set up loss and optimizer -------------- #\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)\n",
    "loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(\n",
    "    from_logits=True)\n",
    "\n",
    "# ----------- 4. training -------------------------------- #\n",
    "all_logits = []\n",
    "for e in range(100):\n",
    "    \n",
    "    with tf.GradientTape() as tape:\n",
    "        tape.watch(inputs) # optimize embedding layer also\n",
    "        \n",
    "        # forward\n",
    "        logits = net(g, inputs)\n",
    "\n",
    "        # compute loss\n",
    "        loss = loss_fcn(tf.gather(labels, labeled_nodes), \n",
    "                        tf.gather(logits, labeled_nodes))\n",
    "\n",
    "        # backward\n",
    "        grads = tape.gradient(loss, net.trainable_weights + node_embed.trainable_weights)        \n",
    "        optimizer.apply_gradients(zip(grads, net.trainable_weights + node_embed.trainable_weights))\n",
    "        all_logits.append(logits.numpy())\n",
    "    \n",
    "    if e % 5 == 0:\n",
    "        print('In epoch {}, loss: {}'.format(e, loss))"
   ]
  },
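  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss minimized in the loop above can be written explicitly. As a sketch, let $\\mathcal{V}_L$ be the set of labeled nodes (here nodes 0 and 33), $z_v$ the logit vector of node $v$, and $y_v$ its label; these symbols are notation introduced here for illustration. With `from_logits=True` and the default averaging over the batch, the loss is the mean softmax cross-entropy over the labeled nodes:\n",
    "\n",
    "$$\n",
    "\\mathcal{L} = -\\frac{1}{|\\mathcal{V}_L|}\\sum_{v\\in\\mathcal{V}_L}\\log\\frac{\\exp(z_{v,y_v})}{\\sum_{c}\\exp(z_{v,c})}\n",
    "$$\n",
    "\n",
    "The unlabeled nodes do not appear in the loss, but they still influence it because their embeddings are aggregated into the representations of the labeled nodes."
   ]
  },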
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------- 5. check results ------------------------ #\n",
    "pred = tf.argmax(logits, axis=1).numpy()\n",
    "print('Accuracy', (pred == labels.numpy()).sum().item() / len(pred))"
   ]
  },
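  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The accuracy above is computed over all 34 nodes, including the two nodes whose labels were used for training. A minimal sketch of a stricter check, assuming the variables `pred`, `labels`, and `labeled_nodes` from the previous cells are still in scope, scores only the remaining nodes; the name `unlabeled_mask` is introduced here for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Boolean mask that is True for every node except the labeled (training) ones.\n",
    "unlabeled_mask = np.ones(len(pred), dtype=bool)\n",
    "unlabeled_mask[labeled_nodes] = False\n",
    "\n",
    "# Fraction of unlabeled nodes whose predicted class matches the true label.\n",
    "unlabeled_acc = (pred[unlabeled_mask] == labels.numpy()[unlabeled_mask]).mean()\n",
    "print('Accuracy on unlabeled nodes', unlabeled_acc)"
   ]
  },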
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize the result\n",
    "\n",
    "Since the GNN produces a logit vector of size 2 for each array. We can plot to a 2-D plane.\n",
    "\n",
    "<img src='../asset/gnn_ep0.png' align='center' width=\"400px\" height=\"300px\"/>\n",
    "<img src='../asset/gnn_ep_anime.gif' align='center' width=\"400px\" height=\"300px\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the following code to visualize the result. Require ffmpeg."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A bit of setup, just ignore this cell\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# for auto-reloading external modules\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline\n",
    "plt.rcParams['figure.figsize'] = (4.0, 3.0) # set default size of plots\n",
    "plt.rcParams['image.interpolation'] = 'nearest'\n",
    "plt.rcParams['image.cmap'] = 'gray'\n",
    "plt.rcParams['animation.html'] = 'html5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the node classification using the logits output. Requires ffmpeg.\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.animation as animation\n",
    "from IPython.display import HTML\n",
    "\n",
    "fig = plt.figure(dpi=150)\n",
    "fig.clf()\n",
    "ax = fig.subplots()\n",
    "nx_G = g.to_networkx()\n",
    "def draw(i):\n",
    "    cls1color = '#00FFFF'\n",
    "    cls2color = '#FF00FF'\n",
    "    pos = {}\n",
    "    colors = []\n",
    "    for v in range(34):\n",
    "        pred = all_logits[i].numpy()\n",
    "        pos[v] = pred[v]\n",
    "        cls = labels[v]\n",
    "        colors.append(cls1color if cls else cls2color)\n",
    "    ax.cla()\n",
    "    ax.axis('off')\n",
    "    ax.set_title('Epoch: %d' % i)\n",
    "    nx.draw(nx_G.to_undirected(), pos, node_color=colors, with_labels=True, node_size=200)\n",
    "\n",
    "ani = animation.FuncAnimation(fig, draw, frames=len(all_logits), interval=200)\n",
    "HTML(ani.to_html5_video())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercise\n",
    "\n",
    "Play with the GNN models by using other [graph convolution modules](https://docs.dgl.ai/api/python/nn.pytorch.html#module-dgl.nn.pytorch.conv)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dgl]",
   "language": "python",
   "name": "conda-env-dgl-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
